package com.icesoft.core.common.util;

import lombok.extern.slf4j.Slf4j;

import javax.crypto.BadPaddingException;
import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.security.Key;
import java.security.interfaces.RSAKey;
import java.security.interfaces.RSAPrivateKey;
import java.security.interfaces.RSAPublicKey;
import java.util.Base64;

/**
 * Static helpers for RSA encryption and decryption of raw bytes and of UTF-8 strings
 * whose ciphertext is carried as Base64 text.
 *
 * <p>Data is processed in fixed-size chunks: each chunk goes through one
 * {@link Cipher#doFinal(byte[], int, int)} call and the per-chunk outputs are concatenated.
 * Decryption splits the ciphertext into chunks of the key's modulus length in bytes.
 *
 * <p>NOTE(review): the chunk sizes assume the cipher returned by {@code CipherUtils.getCipher}
 * for the transformation {@code "RSA"} applies PKCS#1 v1.5 padding (117 = 128 - 11) — confirm
 * against {@code CipherUtils} and the configured security provider.
 */
@Slf4j
public class RsaUtil {

    /** Charset used for every String-to-bytes and bytes-to-String conversion in this class. */
    private static final Charset DEFAULT_ENCODING = StandardCharsets.UTF_8;

    /** Transformation name handed to {@code CipherUtils.getCipher}. */
    private static final String RSA_TRANSFORMATION = "RSA";

    /**
     * Upper bound on the number of plaintext bytes encrypted per chunk. Kept at its historical
     * value so that the output layout for keys of 1024 bits and above does not change.
     */
    private static final int MAX_ENCRYPT_BLOCK = 117;

    /** Bytes reserved per block by PKCS#1 v1.5 padding (assumed padding scheme, see class note). */
    private static final int PKCS1_PADDING_OVERHEAD = 11;

    /**
     * Encrypts a string with a public key and returns the ciphertext as Base64 text.
     *
     * @param publicKey key to encrypt with; must not be {@code null}
     * @param string    plaintext, converted to bytes as UTF-8; must not be {@code null}
     * @return Base64 encoding of the concatenated encrypted chunks
     * @throws IllegalArgumentException if {@code publicKey} is {@code null} or encryption fails
     */
    public static String encryptWithBase64(RSAPublicKey publicKey, String string) throws IllegalArgumentException {
        byte[] binaryData = encrypt(publicKey, string.getBytes(DEFAULT_ENCODING));
        return Base64.getEncoder().encodeToString(binaryData);
    }

    /**
     * Decodes Base64 ciphertext, decrypts it with a private key and returns the result as a string.
     *
     * @param privateKey   key to decrypt with; must not be {@code null}
     * @param base64String Base64-encoded ciphertext; must not be {@code null}
     * @return decrypted bytes interpreted as UTF-8
     * @throws IllegalArgumentException if {@code privateKey} is {@code null}, the text is not valid
     *                                  Base64, or decryption fails
     */
    public static String decryptWithBase64(RSAPrivateKey privateKey, String base64String)
            throws IllegalArgumentException {
        byte[] binaryData = decrypt(privateKey, Base64.getDecoder().decode(base64String.getBytes(DEFAULT_ENCODING)));
        return new String(binaryData, DEFAULT_ENCODING);
    }

    /**
     * Encrypts a string with a private key and returns the ciphertext as Base64 text.
     *
     * @param privateKey key to encrypt with; must not be {@code null}
     * @param string     plaintext, converted to bytes as UTF-8; must not be {@code null}
     * @return Base64 encoding of the concatenated encrypted chunks
     * @throws IllegalArgumentException if {@code privateKey} is {@code null} or encryption fails
     */
    public static String encryptWithBase64(RSAPrivateKey privateKey, String string) throws IllegalArgumentException {
        byte[] binaryData = encrypt(privateKey, string.getBytes(DEFAULT_ENCODING));
        return Base64.getEncoder().encodeToString(binaryData);
    }

    /**
     * Decodes Base64 ciphertext, decrypts it with a public key and returns the result as a string.
     *
     * @param publicKey    key to decrypt with; must not be {@code null}
     * @param base64String Base64-encoded ciphertext; must not be {@code null}
     * @return decrypted bytes interpreted as UTF-8
     * @throws IllegalArgumentException if {@code publicKey} is {@code null}, the text is not valid
     *                                  Base64, or decryption fails
     */
    public static String decryptWithBase64(RSAPublicKey publicKey, String base64String)
            throws IllegalArgumentException {
        byte[] binaryData = decrypt(publicKey, Base64.getDecoder().decode(base64String.getBytes(DEFAULT_ENCODING)));
        return new String(binaryData, DEFAULT_ENCODING);
    }

    /**
     * Encrypts bytes with a public key, chunk by chunk.
     *
     * @param publicKey   key to encrypt with
     * @param processData plaintext bytes; an empty array yields an empty result
     * @return concatenation of the encrypted chunks
     * @throws IllegalArgumentException if {@code publicKey} is {@code null} or encryption fails
     */
    public static byte[] encrypt(RSAPublicKey publicKey, byte[] processData) throws IllegalArgumentException {
        if (publicKey == null) {
            throw new IllegalArgumentException("加密公钥为空, 请设置");
        }
        return encryptByKey(publicKey, publicKey, processData);
    }

    /**
     * Encrypts bytes with a private key, chunk by chunk.
     *
     * @param privateKey  key to encrypt with
     * @param processData plaintext bytes; an empty array yields an empty result
     * @return concatenation of the encrypted chunks
     * @throws IllegalArgumentException if {@code privateKey} is {@code null} or encryption fails
     */
    public static byte[] encrypt(RSAPrivateKey privateKey, byte[] processData) throws IllegalArgumentException {
        if (privateKey == null) {
            // The message previously complained about a missing *public* key.
            throw new IllegalArgumentException("加密私钥为空, 请设置");
        }
        return encryptByKey(privateKey, privateKey, processData);
    }

    /**
     * Decrypts bytes with a private key, chunk by chunk.
     *
     * @param privateKey key to decrypt with
     * @param cipherData ciphertext bytes; an empty array yields an empty result
     * @return concatenation of the decrypted chunks
     * @throws IllegalArgumentException if {@code privateKey} is {@code null} or decryption fails
     */
    public static byte[] decrypt(RSAPrivateKey privateKey, byte[] cipherData) throws IllegalArgumentException {
        if (privateKey == null) {
            throw new IllegalArgumentException("解密私钥为空, 请设置");
        }
        return decrypt(privateKey, privateKey, cipherData);
    }

    /**
     * Decrypts bytes with a public key, chunk by chunk.
     *
     * @param publicKey  key to decrypt with
     * @param cipherData ciphertext bytes; an empty array yields an empty result
     * @return concatenation of the decrypted chunks
     * @throws IllegalArgumentException if {@code publicKey} is {@code null} or decryption fails
     */
    public static byte[] decrypt(RSAPublicKey publicKey, byte[] cipherData) throws IllegalArgumentException {
        if (publicKey == null) {
            throw new IllegalArgumentException("解密公钥为空, 请设置");
        }
        return decrypt(publicKey, publicKey, cipherData);
    }

    /**
     * Obtains a cipher for the {@code "RSA"} transformation from {@code CipherUtils.getCipher}.
     *
     * @param key  key passed through to {@code CipherUtils.getCipher}
     * @param mode cipher mode constant, e.g. {@link Cipher#ENCRYPT_MODE} or {@link Cipher#DECRYPT_MODE}
     * @return whatever {@code CipherUtils.getCipher} returns; presumably already initialised with
     *         {@code key} and {@code mode} — verify against {@code CipherUtils}
     */
    public static Cipher getRsaCipher(Key key, int mode) {
        return CipherUtils.getCipher(RSA_TRANSFORMATION, key, mode);
    }

    /**
     * Builds an encrypting cipher for {@code key} and encrypts {@code processData} in chunks
     * sized for that key. {@code key} and {@code rsaKey} are the same object seen through
     * its two interfaces.
     */
    private static byte[] encryptByKey(Key key, RSAKey rsaKey, byte[] processData) throws IllegalArgumentException {
        Cipher cipher = getRsaCipher(key, Cipher.ENCRYPT_MODE);
        return encrypt(cipher, encryptChunkLength(rsaKey), processData);
    }

    /**
     * Builds a decrypting cipher for {@code key} and decrypts {@code cipherData} in chunks of
     * the modulus length. {@code key} and {@code rsaKey} are the same object seen through
     * its two interfaces.
     */
    private static byte[] decrypt(Key key, RSAKey rsaKey, byte[] cipherData) throws IllegalArgumentException {
        Cipher cipher = getRsaCipher(key, Cipher.DECRYPT_MODE);
        return decrypt(cipher, modulusByteLength(rsaKey), cipherData);
    }

    /**
     * Returns the modulus length in bytes, rounded up. Rounding down would produce a chunk one
     * byte too short for keys whose modulus bit length is not a multiple of 8 (e.g. 2047 bits).
     */
    private static int modulusByteLength(RSAKey rsaKey) {
        return (rsaKey.getModulus().bitLength() + 7) / 8;
    }

    /**
     * Returns the plaintext chunk size to use when encrypting with {@code rsaKey}: the historical
     * 117 bytes, reduced for keys shorter than 1024 bits so that a chunk plus padding still fits
     * into one block.
     *
     * @throws IllegalArgumentException if the key leaves no room for any plaintext
     */
    private static int encryptChunkLength(RSAKey rsaKey) {
        int chunkLen = Math.min(MAX_ENCRYPT_BLOCK, modulusByteLength(rsaKey) - PKCS1_PADDING_OVERHEAD);
        if (chunkLen <= 0) {
            // A non-positive chunk size would make the chunking loop spin forever.
            throw new IllegalArgumentException("RSA密钥长度过短, 无法加密");
        }
        return chunkLen;
    }

    /**
     * Encrypts {@code processData} chunk by chunk and maps cipher failures to
     * {@link IllegalArgumentException}.
     *
     * @param cipher      cipher used for every chunk
     * @param chunkLen    maximum number of plaintext bytes per chunk
     * @param processData plaintext bytes
     * @return concatenated ciphertext
     */
    private static byte[] encrypt(Cipher cipher, int chunkLen, byte[] processData) {
        try {
            return transformInChunks(cipher, processData, chunkLen);
        } catch (IllegalBlockSizeException e) {
            throw new IllegalArgumentException("密文长度非法", e);
        } catch (IOException e) {
            throw new IllegalArgumentException("IOException", e);
        } catch (BadPaddingException e) {
            throw new IllegalArgumentException("密钥错误", e);
        }
    }

    /**
     * Decrypts {@code cipherData} chunk by chunk and maps cipher failures to
     * {@link IllegalArgumentException}.
     *
     * @param cipher     cipher used for every chunk
     * @param chunkLen   number of ciphertext bytes per chunk
     * @param cipherData ciphertext bytes
     * @return concatenated plaintext
     */
    private static byte[] decrypt(Cipher cipher, int chunkLen, byte[] cipherData) {
        try {
            return transformInChunks(cipher, cipherData, chunkLen);
        } catch (IllegalBlockSizeException e) {
            throw new IllegalArgumentException("密文错误", e);
        } catch (IOException e) {
            log.error("解密错误", e);
            throw new IllegalArgumentException("IOException", e);
        } catch (BadPaddingException e) {
            throw new IllegalArgumentException("密钥错误", e);
        }
    }

    /**
     * Feeds {@code data} to {@code cipher.doFinal} in consecutive slices of at most
     * {@code chunkLen} bytes and concatenates the results. The last slice may be shorter.
     *
     * @param cipher   cipher that transforms each slice
     * @param data     input bytes; an empty array produces an empty result without calling the cipher
     * @param chunkLen slice size in bytes; callers pass a positive value
     * @return the per-slice outputs, in order, as one array
     */
    private static byte[] transformInChunks(Cipher cipher, byte[] data, int chunkLen)
            throws IllegalBlockSizeException, BadPaddingException, IOException {
        // The stream is owned here only, so it is closed exactly once.
        try (ByteArrayOutputStream out = new ByteArrayOutputStream()) {
            for (int offset = 0; offset < data.length; offset += chunkLen) {
                int sliceLen = Math.min(chunkLen, data.length - offset);
                byte[] transformed = cipher.doFinal(data, offset, sliceLen);
                out.write(transformed, 0, transformed.length);
            }
            return out.toByteArray();
        }
    }
}